import pytest
from pr2test.context_manager import make_test_matrix, skip_if_not_supported
from pr2test.marks import require_root
from pr2test.tools import qdisc_exists

from pyroute2 import protocols

# Module-level pytest marks, applied to every test in this file;
# require_root() presumably skips the tests unless run with root
# privileges (tc changes need them) — see pr2test.marks.
pytestmark = [require_root()]
# Values for the indirectly parametrized ``context`` fixture: each test
# below runs once per target listed here ('local' and 'netns').
test_matrix = make_test_matrix(targets=['local', 'netns'])


@pytest.mark.parametrize('context', test_matrix, indirect=True)
@skip_if_not_supported
def test_htb_over_32gbit(context):
    """Check HTB classes whose rates exceed the 32-bit rate range.

    Rates at or above 2**32 bytes/s (roughly 34.4 gbit) are reported by the
    kernel through the TCA_HTB_RATE64 / TCA_HTB_CEIL64 attributes. The
    50gbit and 35gbit classes created here need them; the 128kbit class
    does not, so its 64-bit attributes are expected to be absent (read
    back as 0 below).
    """
    index, ifname = context.default_interface

    # 8<-----------------------------------------------------
    # root queue, '1:0' handle notation
    context.ipr.tc('add', 'htb', index=index, handle='1:', default='20:0')

    assert qdisc_exists(context.netns, 'htb', ifname=ifname)

    # 8<-----------------------------------------------------
    # classes, int handle notation: 1:1 under the root,
    # 1:10 and 1:20 under 1:1
    context.ipr.tc(
        'add-class',
        'htb',
        index=index,
        handle=0x10001,
        parent=0x10000,
        rate='50gbit',
        ceil='50gbit',
    )
    context.ipr.tc(
        'add-class',
        'htb',
        index=index,
        handle=0x10010,
        parent=0x10001,
        rate='35gbit',
        ceil='35gbit',
        prio=1,
    )
    context.ipr.tc(
        'add-class',
        'htb',
        index=index,
        handle=0x10020,
        parent=0x10001,
        rate='128kbit',
        ceil='128kbit',
        prio=2,
    )

    # 8<-----------------------------------------------------
    # list the installed classes and index them by handle: the order
    # of the kernel dump follows its internal class hash, not the
    # creation order, so positional access would be fragile
    dump = tuple(context.ipr.get_classes(index=index))
    assert len(dump) == 3
    classes = {msg['handle']: msg for msg in dump}
    assert set(classes) == {0x10001, 0x10010, 0x10020}

    def rate_ceil(handle):
        # 64-bit rate/ceil of one class; 0 when the attribute is absent
        msg = classes[handle]
        return (
            msg.get(('TCA_OPTIONS', 'TCA_HTB_RATE64'), 0),
            msg.get(('TCA_OPTIONS', 'TCA_HTB_CEIL64'), 0),
        )

    top_rate, top_ceil = rate_ceil(0x10001)  # 50gbit
    high_rate, high_ceil = rate_ceil(0x10010)  # 35gbit
    low_rate, low_ceil = rate_ceil(0x10020)  # 128kbit

    # every class was created with rate == ceil
    assert top_rate == top_ceil
    assert high_rate == high_ceil
    assert low_rate == low_ceil
    # both multi-gigabit classes must carry 64-bit values,
    # and they must preserve the requested ordering
    assert top_rate > high_rate > 0
    # 128kbit fits into 32 bits: no 64-bit attribute is expected
    assert low_rate == 0


@pytest.mark.parametrize('context', test_matrix, indirect=True)
@skip_if_not_supported
def test_htb(context):
    """Build an HTB tree with sfq leaves and u32 filters on one interface.

    Handles are deliberately given in both string ('1:10') and integer
    (0x10010) form, so both notations accepted by ``tc()`` are used.
    ``test_replace`` calls this function to prepare its starting state.
    """
    index, ifname = context.default_interface

    # -- root qdisc, written in '1:' ('1:0') handle notation ---------------
    context.ipr.tc('add', 'htb', index=index, handle='1:', default='20:0')
    assert qdisc_exists(context.netns, 'htb', ifname=ifname)

    # -- classes: 1:1 on the root, 1:10 and 1:20 under 1:1 -----------------
    # Rows are (handle, parent, rate, extra keyword arguments); string and
    # integer handle notation are mixed on purpose.
    class_rows = (
        ('1:1', '1:0', '256kbit', {}),
        (0x10010, 0x10001, '192kbit', {'prio': 1}),
        ('1:20', '1:1', '128kbit', {'prio': 2}),
    )
    for handle, parent, rate, extra in class_rows:
        context.ipr.tc(
            'add-class',
            'htb',
            index=index,
            handle=handle,
            parent=parent,
            rate=rate,
            burst=1024 * 6,
            **extra
        )
    class_count = len(tuple(context.ipr.get_classes(index=index)))
    assert class_count == 3

    # -- sfq leaves: 10: under 1:10, 20: (0x200000) under 1:20 -------------
    context.ipr.tc(
        'add', 'sfq', index=index, handle='10:', parent='1:10', perturb=10
    )
    context.ipr.tc(
        'add', 'sfq', index=index, handle=0x200000, parent=0x10020, perturb=10
    )
    kinds = {
        qdisc.get_attr('TCA_KIND')
        for qdisc in context.ipr.get_qdiscs()
        if qdisc['index'] == index
    }
    assert kinds == {'htb', 'sfq'}

    # -- u32 filters, string and integer handle notation -------------------
    # The u32 classifier expects Ethernet protocol numbers from the
    # ``protocols`` module (e.g. protocols.ETH_P_IP); do not pass socket
    # address families such as socket.AF_INET here.
    context.ipr.tc(
        'add-filter',
        'u32',
        index=index,
        handle='0:0',
        parent='1:0',
        prio=10,
        protocol=protocols.ETH_P_IP,
        target='1:10',
        keys=['0x0006/0x00ff+8', '0x0000/0xffc0+2'],
    )
    context.ipr.tc(
        'add-filter',
        'u32',
        index=index,
        handle=0,
        parent=0x10000,
        prio=10,
        protocol=protocols.ETH_P_IP,
        target=0x10020,
        keys=['0x5/0xf+0', '0x10/0xff+33'],
    )
    # the two filters above plus two autogenerated entries
    filter_count = len(tuple(context.ipr.get_filters(index=index)))
    assert filter_count == 4


@pytest.mark.parametrize('context', test_matrix, indirect=True)
@skip_if_not_supported
def test_replace(context):
    """Replace an existing HTB class in place and check its new parameters.

    The hierarchy from ``test_htb`` is built first, then class 1:10
    (0x10010) is replaced with a new rate and priority under the same
    parent.
    """
    test_htb(context)
    index, _ = context.default_interface
    # change class 1:10: new rate and priority, same parent
    context.ipr.tc(
        'replace-class',
        'htb',
        index=index,
        handle=0x10010,
        parent=0x10001,
        rate='102kbit',
        burst=1024 * 6,
        prio=3,
    )
    # materialise the whole dump before searching, so the netlink
    # dump is always fully consumed
    classes = tuple(context.ipr.get_classes(index=index))
    target = next(
        (msg for msg in classes if msg['handle'] == 0x10010),
        None,
    )
    if target is None:
        pytest.fail('target class not found')
    params = target.get_attr('TCA_OPTIONS').get_attr('TCA_HTB_PARMS')

    assert params['prio'] == 3
    # 102kbit is 12750 bytes/s; HTB derives a leaf's quantum as
    # rate / r2q (r2q presumably left at its default of 10 here),
    # i.e. 1275 bytes, which is 10200 when expressed in bits.
    assert params['quantum'] * 8 == 10200
